{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QJ5krjrcbuiA"
   },
   "source": [
    "# Notes on Generalized Advantage Estimation\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UPVl5LBEWJ0t"
   },
   "source": [
    "## How to compute GAE on your own?\n",
    "(Note that for this reading you need to understand the calculation of [GAE](https://arxiv.org/abs/1506.02438) advantage first)\n",
    "\n",
    "In terms of code implementation, perhaps the most difficult and annoying part is computing GAE advantage. Just now, we use the `self.compute_episodic_return()` method inherited from `BasePolicy` to save us from all those troubles. However, it is still important that we know the details behind this.\n",
    "\n",
    "To compute GAE advantage, the usage of `self.compute_episodic_return()` may go like:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "D34GlVvPNz08",
    "outputId": "43a4e5df-59b5-4e4a-c61c-e69090810215"
   },
   "source": [
    "```python\n",
    "batch, indices = dummy_buffer.sample(0)  # 0 means sampling all the data from the buffer\n",
    "returns, advantage = Algorithm.compute_episodic_return(\n",
    "    batch=batch,\n",
    "    buffer=dummy_buffer,\n",
    "    indices=indices,\n",
    "    v_s_=np.zeros(10),\n",
    "    v_s=np.zeros(10),\n",
    "    gamma=1.0,\n",
    "    gae_lambda=1.0,\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the code above, we sample all the 10 data in the buffer and try to compute the GAE advantage. However, the way the returns are computed here might be a bit misleading. In fact, the last episode is unfinished, but its last step saved in the batch is treated as a terminal state, since it assumes that there are no future rewards. The episode is not terminated yet, it is truncated, so the agent could still get rewards in the future. Terminated and truncated episodes should indeed be treated differently.\n",
    "The return of a step is the (discounted) sum of the future rewards from that step until the end of the episode. \n",
    "\\begin{equation}\n",
    "R_{t}=\\sum_{t}^{T} \\gamma^{t} r_{t}\n",
    "\\end{equation}\n",
    "Thus, at the last step of a terminated episode the return is equal to the reward at that state, since there are no future states.\n",
    "\\begin{equation}\n",
    "R_{T,terminated}=r_{T}\n",
    "\\end{equation}\n",
    "\n",
    "However, if the episode was truncated the return at the last step is usually better represented by the estimated value of that state, which is the expected return from that state onwards.\n",
    "\\begin{align*}\n",
    "R_{T,truncated}=V^{\\pi}\\left(s_{T}\\right)  \\quad & \\text{or} \\quad R_{T,truncated}=Q^{\\pi}(s_{T},a_{T})\n",
    "\\end{align*}\n",
    "Moreover, if the next state was also observed (but not its reward), then an even better estimate would be the reward of the last step plus the discounted value of the next state.\n",
    "\\begin{align*}\n",
    "R_{T,truncated}=r_T+\\gamma V^{\\pi}\\left(s_{T+1}\\right)\n",
    "\\end{align*}"
   ]
  },
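  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the difference concrete, consider a small hypothetical example (the numbers are chosen purely for illustration): an episode with three stored steps, rewards $r_1=r_2=r_3=1$, and $\\gamma=0.9$. If the episode terminated at step 3, computing the returns backwards with $R_t = r_t + \\gamma R_{t+1}$ gives\n",
    "\\begin{align*}\n",
    "R_{3} &= r_{3} = 1 \\\\\n",
    "R_{2} &= r_{2} + \\gamma R_{3} = 1 + 0.9 \\cdot 1 = 1.9 \\\\\n",
    "R_{1} &= r_{1} + \\gamma R_{2} = 1 + 0.9 \\cdot 1.9 = 2.71\n",
    "\\end{align*}\n",
    "If instead the episode was truncated after step 3 and we assume a critic estimate of $V^{\\pi}(s_{4}) = 10$, bootstrapping from the next state gives\n",
    "\\begin{align*}\n",
    "R_{3} &= r_{3} + \\gamma V^{\\pi}(s_{4}) = 1 + 0.9 \\cdot 10 = 10 \\\\\n",
    "R_{2} &= r_{2} + \\gamma R_{3} = 1 + 0.9 \\cdot 10 = 10 \\\\\n",
    "R_{1} &= r_{1} + \\gamma R_{2} = 1 + 0.9 \\cdot 10 = 10\n",
    "\\end{align*}\n",
    "Treating the truncated step as terminal would therefore underestimate every return in this episode, not just the last one."
   ]
  },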
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "h_5Dt6XwQLXV"
   },
   "source": [
    "\n",
    "As we know, we need to estimate the value function of every observation to compute GAE advantage. So in `v_s` is the value of `batch.obs`, and in `v_s_` is the value of `batch.obs_next`. This is usually computed by:\n",
    "\n",
    "`v_s = critic(batch.obs)`,\n",
    "\n",
    "`v_s_ = critic(batch.obs_next)`,\n",
    "\n",
    "where both `v_s` and `v_s_` are 10 dimensional arrays and `critic` is usually a neural network.\n",
    "\n",
    "After we've got all those values, GAE can be computed following the equation below."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ooHNIICGUO19"
   },
   "source": [
    "\\begin{aligned}\n",
    "\\hat{A}_{t}^{\\mathrm{GAE}(\\gamma, \\lambda)}: =& \\sum_{l=0}^{\\infty}(\\gamma \\lambda)^{l} \\delta_{t+l}^{V}\n",
    "\\end{aligned}\n",
    "\n",
    "where\n",
    "\n",
    "\\begin{equation}\n",
    "\\delta_{t}^{V} \\quad=-V\\left(s_{t}\\right)+r_{t}+\\gamma V\\left(s_{t+1}\\right)\n",
    "\\end{equation}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "eV6XZaouU7EV"
   },
   "source": [
    "Unfortunately, if you  follow this equation, which is taken from the paper, you probably will get a slightly lower performance than you expected. There are at least 3 \"bugs\" in this equation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FCxD9gNNVYbd"
   },
   "source": [
    "**First** is that Gym always returns you a `obs_next` even if this is already the last step. The value of this timestep is exactly 0 and you should not let the neural network estimate it."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "rNZNUNgQVvRJ",
    "outputId": "44354595-c25a-4da8-b4d8-cffa31ac4b7d"
   },
   "source": [
    "```python\n",
    "# Assume v_s_ is got by calling critic(batch.obs_next)\n",
    "v_s_ = np.ones(10)\n",
    "v_s_ *= ~batch.done\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2EtMi18QWXTN"
   },
   "source": [
    "After the fix above, we will perhaps get a more accurate estimate.\n",
    "\n",
    "**Secondly**, you must know when to stop bootstrapping. Usually we stop bootstrapping when we meet a `done` flag. However, in the buffer above, the last (10th) step is not marked by done=True, because the collecting has not finished. We must know all those unfinished steps so that we know when to stop bootstrapping.\n",
    "\n",
    "Luckily, this can be done under the assistance of buffer because buffers in Tianshou not only store data, but also help you manage data trajectories."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "saluvX4JU6bC",
    "outputId": "2994d178-2f33-40a0-a6e4-067916b0b5c5"
   },
   "source": [
    "```python\n",
    "unfinished_indexes = dummy_buffer.unfinished_index()\n",
    "done_indexes = np.where(batch.done)[0]\n",
    "stop_bootstrap_ids = np.concatenate([unfinished_indexes, done_indexes])\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qp6vVE4dYWv1"
   },
   "source": [
    "**Thirdly**, there are some special indexes which are marked by done flag, however its value for obs_next should not be zero. It is again because done does not differentiate between terminated and truncated. These steps are usually those at the last step of an episode, but this episode stops not because the agent can no longer get any rewards (value=0), but because the episode is too long so we have to truncate it. These kind of steps are always marked with `info['TimeLimit.truncated']=True` in Gym."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tWkqXRJfZTvV"
   },
   "source": [
    "As a result, we need to rewrite the equation above\n",
    "\n",
    "`v_s_ *= ~batch.done`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kms-QtxKZe-M"
   },
   "source": [
    "to\n",
    "\n",
    "```\n",
    "mask = batch.info['TimeLimit.truncated'] | (~batch.done)\n",
    "v_s_ *= mask\n",
    "\n",
    "```\n",
    "\n",
    "\n",
    "\n"
   ]
  },
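  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Putting the three fixes together, the computation can be sketched as follows. The symbols $m_t$ and $e_t$ are introduced here only for illustration and do not appear in the paper; this is a sketch of the idea under those assumptions, not necessarily the exact form used inside Tianshou.\n",
    "\\begin{align*}\n",
    "\\delta_{t} &= r_{t}+\\gamma\\, m_{t}\\, V\\left(s_{t+1}\\right)-V\\left(s_{t}\\right) \\\\\n",
    "\\hat{A}_{t} &= \\delta_{t}+\\gamma \\lambda\\,(1-e_{t})\\, \\hat{A}_{t+1}\n",
    "\\end{align*}\n",
    "Here $m_t = 0$ if step $t$ ends in a true terminal state and $m_t = 1$ otherwise (including time-limit truncation), which corresponds to `mask` above. $e_t = 1$ if step $t$ is in `stop_bootstrap_ids` (a `done` step or the last step of an unfinished episode) and $e_t = 0$ otherwise. The second line is the recursive form of the GAE sum, evaluated backwards over the batch starting from $\\hat{A} = 0$ after the last stored step."
   ]
  },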
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "u_aPPoKraBu6"
   },
   "source": [
    "## Summary\n",
    "If you already felt bored by now, simply remember that Tianshou can help handle all these little details so that you can focus on the algorithm itself. Just call `Algorithm.compute_episodic_return()`.\n",
    "\n",
    "If you still feel interested, we would recommend you check Appendix C in this [paper](https://arxiv.org/abs/2107.14171v2) and implementation of `Algorithm.value_mask()` and `Algorithm.compute_episodic_return()` for details."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2cPnUXRBWKD9"
   },
   "source": [
    "<center>\n",
    "<img src=../_static/images/timelimit.svg></img>\n",
    "</center>\n",
    "<center>\n",
    "<img src=../_static/images/policy_table.svg></img>\n",
    "</center>"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
